#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2015, 2017 Ryan Sawhill Aroha <rsaw@redhat.com>
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
#    General Public License <gnu.org/licenses/gpl.html> for more details.

from __future__ import print_function
import argparse
import time
import errno
import subprocess
import os
import logging
import signal
from sys import stdin, stderr, exit
from shutil import rmtree
from stat import S_IWRITE

# Program name shown in usage/help output, plus release metadata that
# parse_args() folds into the version line of the help epilog.
prog = 'rguard'
vers = {
    'version': '1.0.1',
    'date': '2017/04/26',
}


def ret(returnCode):
    """Return True if *returnCode* is 0, otherwise False.

    Arguments:
        - returnCode -- A process exit status (as returned by subprocess.call)

    Always returns a real bool so that callers can safely compare the result
    against True/False instead of relying on an implicit None.
    """
    return returnCode == 0


def call(cmd, showStdout=True, shell=False):
    """Execute *cmd* and return True on success.

    Arguments:
        - cmd -- Command to run: an argument list, or a string when shell=True
        - showStdout -- When False, the child's stdout is discarded
        - shell -- Passed straight through to subprocess.call()

    Returns True if the command exits with status 0 and False otherwise.
    Only stdout is ever silenced; the child's stderr is left untouched.
    """
    # stdout=None means "inherit from this process", i.e. show the output.
    stdout = None if showStdout else subprocess.DEVNULL
    return subprocess.call(cmd, shell=shell, stdout=stdout) == 0


class RebootGuard:
    """Object for enabling/disabling blocks that prevent shutdown, reboot, etc.

    Keyword arguments:
        - blockMethod -- The method by which shutdown is blocked (has default)

    After initialization, call guard(enforce=True) or guard(enforce=False)
    """

    def __init__(self, blockMethod='systemd_refusemanualstart'):
        """Initialize reboot-blocking method choice.

        Raises ValueError (naming the rejected value and listing the valid
        choices) if *blockMethod* is not a known method name.

        rsaw says: Please feel free to add other methods.
        """
        possibleBlockMethods = {
            'systemd_refusemanualstart': self.reboot_guard_systemd_refusemanualstart,
            # FIXME: Wouldn't be all that useful right now, but it's worth implementing:
            ##'systemd_inhibitorlock': self.reboot_guard_systemd_inhibitorlock,
            }
        if blockMethod not in possibleBlockMethods:
            raise ValueError(
                "Invalid blockMethod ('{}'); valid choices:\n  {}"
                .format(blockMethod, ", ".join(possibleBlockMethods)))
        self.blockMethod = possibleBlockMethods[blockMethod]

    def guard(self, enforce=True):
        """Enable (enforce=True) or disable (enforce=False) reboot-guard."""
        self.blockMethod(enforce)

    def reboot_guard_systemd_refusemanualstart(self, enforce=True, targets=None):
        """Block/unblock entering targets by way of RefuseManualStart=yes/no.

        Arguments:
            - enforce -- True to add the start-blocks, False to remove them
            - targets -- Target names without the '.target' suffix; when
              empty or None, poweroff, reboot and halt are used

        systemd is reloaded only if at least one unit was actually changed.
        """
        if not targets:
            targets = [
                'poweroff',
                'reboot',
                'halt',
                # 'hibernate',
                # 'hybrid-sleep',
                # 'suspend',
                ]
        if enforce:
            changeUnit = self.add_systemd_unit_start_block
        else:
            changeUnit = self.del_systemd_unit_start_block
        reloadDaemon = False
        # Every target must be visited, so no short-circuiting here
        for t in targets:
            if changeUnit(t + '.target'):
                reloadDaemon = True
        if reloadDaemon:
            logging.debug("Reloading systemd via: systemctl daemon-reload")
            if subprocess.call(['systemctl', 'daemon-reload']) != 0:
                logging.error("Unexpected error reloading systemd")

    def systemd_unit_is_blocked(self, unit):
        """Return True if systemd *unit* is configured with RefuseManualStart=yes.

        Returns False if the property has any other value or if systemctl
        could not be queried (in which case an error is logged).
        """
        cmd = ['systemctl', 'show', unit, '-p', 'RefuseManualStart']
        try:
            out = subprocess.check_output(cmd, encoding='utf-8')
        except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
            logging.error("Unexpected error running: {}".format(' '.join(cmd)))
            return False
        return out.strip() == 'RefuseManualStart=yes'

    def add_systemd_unit_start_block(self, unit, cfgFileName='reboot-guard.conf'):
        """Return True after successfully adding a block to prevent start of *unit*.

        The block is a drop-in file named *cfgFileName* written below
        /run/systemd/system/<unit>.d/. Returns False if the unit is already
        blocked or if the drop-in could not be created.
        """
        if self.systemd_unit_is_blocked(unit):
            logging.debug("Unit {} already configured with RefuseManualStart=yes".format(unit))
            return False
        dropInDir = '/run/systemd/system/{}.d'.format(unit)
        cfg = '{}/{}'.format(dropInDir, cfgFileName)
        if not self.mkdir_p(dropInDir):
            return False
        logging.debug("Adding RefuseManualStart=yes to {}".format(cfg))
        try:
            with open(cfg, 'w') as f:
                f.write('[Unit]\nRefuseManualStart=yes\n')
        except OSError:
            logging.error("Unexpected error writing to '{}'".format(cfg))
            return False
        logging.warning("☹  Blocked {}".format(unit))
        return True

    def del_systemd_unit_start_block(self, unit, cfgFileName='reboot-guard.conf'):
        """Return True after successfully unblocking manual start of *unit*.

        Removes the drop-in file *cfgFileName* from
        /run/systemd/system/<unit>.d/. Returns False if the unit is not
        currently blocked or if the file could not be removed.
        """
        if not self.systemd_unit_is_blocked(unit):
            logging.debug("Unit {} already configured with RefuseManualStart=no".format(unit))
            return False
        dropInDir = '/run/systemd/system/{}.d'.format(unit)
        cfg = '{}/{}'.format(dropInDir, cfgFileName)
        logging.debug("Removing {}".format(cfg))
        if self.rm_rf(cfg):
            logging.warning("☻  Unblocked {}".format(unit))
            return True
        return False

    def mkdir_p(self, path):
        """Make directory *path* (via 'mkdir -p'); return True on success."""
        if call(['mkdir', '-p', path]):
            return True
        logging.error("Unexpected error making directory '{}'".format(path))
        return False

    def rm_rf(self, path):
        """Recursively remove *path* (via 'rm -rf'); return True on success."""
        if call(['rm', '-rf', path]):
            return True
        logging.error("Unexpected error deleting path '{}'".format(path))
        return False


class ConditionChecker:
    """An object to handle arbitrary condition checks.

    Keyword arguments:
        - forbidFiles -- Forbidden files for which to check
        - requireFiles -- Required files for which to check
        - units -- Active systemd units for which to check
        - cmds -- Process cmd names for which to check
        - cmdArgs -- Process cmd name+args for which to check
        - runCmds -- Arbitrary commands to execute (return code is checked)

    Every argument expects a list; None (or omitting it) means "no checks of
    this kind".

    After initialization, call check_conditions() which will return True only
    if all checks pass.
    """

    def __init__(self, forbidFiles=None, requireFiles=None, units=None, cmds=None, cmdArgs=None, runCmds=None):
        """Initialize condition checks.

        Each instance gets its own empty list for any omitted argument, so
        instances never share state through a default value.
        """
        self.forbidFiles = [] if forbidFiles is None else forbidFiles
        self.requireFiles = [] if requireFiles is None else requireFiles
        self.units = [] if units is None else units
        self.cmds = [] if cmds is None else cmds
        self.cmdArgs = [] if cmdArgs is None else cmdArgs
        self.runCmds = [] if runCmds is None else runCmds
        # None until the first check_conditions() run; afterwards the bool
        # result of the previous run (used to log only on state changes)
        self.lastStatus = None

    def check_conditions(self):
        """Run through all condition checks and return True if all pass.

        Checks run in a fixed order and stop at the first failure. The
        outcome is remembered in self.lastStatus so that a warning is logged
        only when the overall state changes (or on the very first run).
        """
        if (self.test_forbidden_files() or
                self.test_required_files() or
                self.test_active_units() or
                self.test_running_cmds() or
                self.test_running_cmd_args() or
                self.test_execute_commands()):
            if self.lastStatus or self.lastStatus is None:
                logging.warning("☹  Failed a condition-check")
            else:
                logging.debug("Still failing a condition-check")
            self.lastStatus = False
        else:
            if not self.lastStatus:
                logging.warning("☻  Passed all condition-checks")
            else:
                logging.debug("Still passing all condition-checks")
            self.lastStatus = True
        return self.lastStatus

    def test_forbidden_files(self):
        """Return True (i.e. FAILED) if any forbid-file exists."""
        if self.forbidFiles:
            for f in self.forbidFiles:
                if os.path.exists(f):
                    logging.info("✘ Fail: Forbid-file '{}' shouldn't exist".format(f))
                    return True
            logging.info("✔ Pass: No forbid-files exist")
        elif self.lastStatus is None:
            logging.debug("No forbid-file checks defined")
        return False

    def test_required_files(self):
        """Return True (i.e. FAILED) if any require-file is missing."""
        if self.requireFiles:
            for f in self.requireFiles:
                if not os.path.exists(f):
                    logging.info("✘ Fail: Require-file '{}' should exist".format(f))
                    return True
            logging.info("✔ Pass: All require-files exist")
        elif self.lastStatus is None:
            logging.debug("No require-file checks defined")
        return False

    def test_active_units(self):
        """Return True (i.e. FAILED) if 'systemctl is-active' succeeds for any unit."""
        if self.units:
            for u in self.units:
                if call(['systemctl', 'is-active', u], showStdout=False):
                    logging.info("✘ Fail: Unit '{}' shouldn't be active".format(u))
                    return True
            logging.info("✔ Pass: No active units match")
        elif self.lastStatus is None:
            logging.debug("No systemd unit checks defined")
        return False

    def test_running_cmds(self):
        """Return True (i.e. FAILED) if 'pgrep --exact' finds any listed command."""
        if self.cmds:
            for c in self.cmds:
                if call(['pgrep', '--exact', c], showStdout=False):
                    logging.info("✘ Fail: Process named '{}' shouldn't exist".format(c))
                    return True
            logging.info("✔ Pass: No running processes match commands")
        elif self.lastStatus is None:
            logging.debug("No command checks defined")
        return False

    def test_running_cmd_args(self):
        """Return True (i.e. FAILED) if 'pgrep --full --exact' finds any listed cmd+args."""
        if self.cmdArgs:
            for a in self.cmdArgs:
                if call(['pgrep', '--full', '--exact', a], showStdout=False):
                    logging.info("✘ Fail: Process name+args matching '{}' shouldn't exist".format(a))
                    return True
            logging.info("✔ Pass: No running processes match command args")
        elif self.lastStatus is None:
            logging.debug("No command-args checks defined")
        return False

    def test_execute_commands(self):
        """Return True (i.e. FAILED) if any exec-command has the wrong outcome.

        Each entry of runCmds is a string with optional prefixes:
            - '!' -- the command is required to FAIL (non-zero exit)
            - '@' -- run the remainder through the shell; otherwise the
              string is split on whitespace and executed directly

        An entry that leaves no command to execute is logged as an error and
        counted as a failed check instead of raising.
        """
        z = {True: 'zero', False: 'non-zero'}
        if self.runCmds:
            for c in self.runCmds:
                desiredOutcome = True
                # startswith() is safe on empty strings, unlike c[0]
                if c.startswith('!'):
                    desiredOutcome = False
                    c = c[1:]
                if c.startswith('@'):
                    shell = True
                    c = c[1:]
                    logging.debug("Running shell cmd: {}".format(c))
                else:
                    shell = False
                    logging.debug("Running cmd: {}".format(c))
                    c = c.split()
                    if not c:
                        # An empty argv cannot be executed; be conservative
                        logging.error("Exec-command is empty; treating it as a failed check")
                        return True
                if call(c, showStdout=False, shell=shell) != desiredOutcome:
                    logging.info("✘ Fail: Exec-command {{{}}} needs to return {}".format(c, z[desiredOutcome]))
                    return True
            logging.info("✔ Pass: All exec-command tests return as required")
        elif self.lastStatus is None:
            logging.debug("No exec-command checks defined")
        return False


class GracefulDeath:
    """Catch signals to allow graceful shutdown.

    Instantiating this class installs handler() for HUP, INT, QUIT, USR1,
    USR2 and TERM. The handler only records what happened; callers poll:
        - receivedSignal -- True once any of the caught signals arrived
        - receivedTermSignal -- True once INT, QUIT or TERM arrived
        - lastSignal -- Number of the most recent signal (None if none yet)
    """

    # Symbolic names are used instead of raw numbers because the numeric
    # values of some signals (e.g. USR1/USR2) differ between platforms.
    _termSignals = (signal.SIGINT, signal.SIGQUIT, signal.SIGTERM)
    _otherSignals = (signal.SIGHUP, signal.SIGUSR1, signal.SIGUSR2)

    def __init__(self):
        """Reset state and register handler() for every caught signal."""
        self.lastSignal = None
        self.receivedSignal = self.receivedTermSignal = False
        for signum in self._otherSignals + self._termSignals:
            signal.signal(signum, self.handler)

    def handler(self, signum, frame):
        """Record receipt of *signum*; *frame* is unused."""
        self.lastSignal = signum
        self.receivedSignal = True
        if signum in self._termSignals:
            self.receivedTermSignal = True
            

class CustomFormatter(argparse.RawDescriptionHelpFormatter):
    """Help formatter that shows an option's metavar only once.

    Produces '-f, --forbid-file FILE' rather than repeating the metavar
    after every option string.
    """

    def _format_action_invocation(self, action):
        """Return the invocation text for *action* as shown in the help listing."""
        flags = action.option_strings
        if not flags:
            # Positional argument: just its metavar (falling back to dest)
            (name,) = self._metavar_formatter(action, action.dest)(1)
            return name
        joined = ', '.join(flags)
        if action.nargs == 0:
            # Flag that takes no value: list the option strings only
            return joined
        # Option taking a value: append the value placeholder a single time
        placeholder = self._format_args(action, action.dest.upper())
        return '{} {}'.format(joined, placeholder)


# Usage examples placed at the start of the --help epilog by parse_args().
# Every '{0}' placeholder is replaced with the program name (prog).
helpExamples= """EXAMPLES:
  As mentioned above, all of the condition-checking options can be used
  multiple times. All specified checks must pass to allow shutdown.

  {0} --forbid-file /shutdown/prevented/if/this/file/exists

  {0} --require-file /shutdown/prevented/until/this/file/exists

  {0} --cmd some-cmd-name-which-when-running-prevents-shutdown
           * e.g., 'bash' or 'firefox'

  {0} --args 'syndaemon -i 1.0 -t -K -R'
           * Prevent shutdown if this exact cmd + args are found running

  {0} --run 'ping -c2 -w1 -W1 10.0.0.1'
           * Allow shutdown only if single command (ping in this case) succeeds

  {0} --run '!ping -c2 -w1 -W1 10.0.0.1'
           * Allow shutdown only if single command FAILS

  {0} --run '!findmnt /mnt'
           * Allow shutdown if nothing mounted at mountpoint

  {0} --run '@some_cmd | some_other_cmd; another_cmd'
           * Allow shutdown only if last cmd in shell succeeds
           * When using '@', full shell syntax is supported
           * e.g.: '|', '&&', '||', ';', '>', '>>', '<', etc

  {0} --run '!@lsof -i:ssh | grep -q ESTABLISHED'
           * Allow shutdown only if shell commands FAIL
           * In this example, only if there are no established ssh connections

""".format(prog)


def parse_args():
    """Parse argv into usable input.

    Returns the argparse namespace. Prints the help text and exits if
    -h/--help was given. Exits with a usage error if --interval is not a
    finite, non-negative number, since such a value could not be slept on
    by the main loop.
    """

    def intervalSeconds(value):
        """Convert *value* to a float number of seconds that is safe to sleep for."""
        try:
            seconds = float(value)
        except ValueError:
            raise argparse.ArgumentTypeError("invalid float value: {!r}".format(value))
        # The chained comparison is also False for NaN
        if not 0 <= seconds < float('inf'):
            raise argparse.ArgumentTypeError(
                "must be a finite, non-negative number of seconds: {!r}".format(value))
        return seconds

    # Setup argument parser
    description = "Block systemd-initiated shutdown until configurable condition checks pass"
    version = "{} v{} last mod {}".format(prog, vers['version'], vers['date'])
    epilog = (
        "{0}"
        "VERSION:\n"
        "  {1}\n"
        "  See <http://github.com/ryran/reboot-guard> to report bugs or RFEs").format(helpExamples, version)
    p = argparse.ArgumentParser(
        prog=prog,
        description=description,
        add_help=False,
        epilog=epilog,
        formatter_class=CustomFormatter)
    # Add args
    g0 = p.add_argument_group(
        'SET AND QUIT',
        description="Execute a single action with no condition checks.")
    gg0 = g0.add_mutually_exclusive_group()
    gg0.add_argument(
        '-1', '--install-guard', action='store_true',
        help="Install reboot-guard and immediately exit")
    gg0.add_argument(
        '-0', '--remove-guard', action='store_true',
        help="Remove reboot-guard (if present) and immediately exit")
    g1 = p.add_argument_group(
        'CONFIGURE CONDITIONS',
        description="Establish what condition checks must pass to allow shutdown.\nEach option may be specified multiple times.")
    g1.add_argument(
        '-f', '--forbid-file', action='append', metavar='FILE',
        help="Prevent shutdown while FILE exists")
    g1.add_argument(
        '-F', '--require-file', action='append', metavar='FILE',
        help="Prevent shutdown until FILE exists")
    g1.add_argument(
        '-u', '--unit', action='append',
        help="Prevent shutdown while systemd UNIT is active")
    g1.add_argument(
        '-c', '--cmd', action='append', metavar='CMD',
        help="Prevent shutdown while CMD exactly matches at least 1 running process")
    g1.add_argument(
        '-a', '--args', action='append', metavar='ARGS',
        help="Prevent shutdown while ARGS exactly matches the full cmd+args of at least 1 running process")
    g1.add_argument(
        '-r', '--run', action='append', metavar='COMMAND',
        help="Prevent shutdown while execution of COMMAND fails; prefix COMMAND with '!' to prevent shutdown while execution succeeds (MORE: Prefix COMMAND with '@' to execute it as a shell command, e.g., to use pipes '|' or other logic; examples: '@cmd|cmd' or '!@cmd|cmd')")
    g2 = p.add_argument_group(
        'GENERAL OPTIONS',
        description="")
    # The built-in help action is disabled (add_help=False) so that -h can
    # be listed inside this option group; it is handled manually below.
    g2.add_argument(
        '-h', '--help', dest='showHelp', action='store_true',
        help="Show this help message and exit")
    g2.add_argument(
        '-i', '--interval', type=intervalSeconds, metavar='SEC', default=60,
        help="Modify the sleep interval between condition checks (default: 60 seconds)")
    g2.add_argument(
        '-n', '--ignore-signals', action='store_true',
        help="Ignore the most common signals (HUP, INT, QUIT, USR1, USR2, TERM), in which case a graceful exit requires the next option or else SIGKILL")
    g2.add_argument(
        '-x', '--exit-on-pass', action='store_true',
        help="Exit (and remove reboot-guard) the first time all condition checks pass")
    g2.add_argument(
        '-v', '--loglevel', choices=['debug', 'info', 'warning', 'error'], default='warning',
        help="Specify minimum message type to print (default: warning)")
    g2.add_argument(
        '-t', '--timestamps', action='store_true',
        help="Enable timestamps in message output (not necessary if running as systemd service)")
    opts = p.parse_args()
    if opts.showHelp:
        p.print_help()
        exit()
    return opts


def main():
    """Parse args, check conditions, install reboot-guard, setup infinite loop.

    Flow:
        1. Install signal handlers, parse options and configure logging.
        2. Refuse to run unless the effective UID is root.
        3. With -1/-0, install/remove the guard once and exit.
        4. Otherwise loop forever: evaluate the condition checks, enforce
           the guard while any check fails and remove it while all pass,
           then wait opts.interval seconds.

    While waiting, pending signals are examined about once per second.
    time.sleep() is not cut short by a handled signal, so sleeping for the
    whole interval in one call would delay a graceful exit by up to a full
    interval plus one more check cycle.
    """
    sighandler = GracefulDeath()
    opts = parse_args()
    logConfig = {
        'format': '%(levelname)s: %(message)s',
        'level': opts.loglevel.upper(),
    }
    if opts.timestamps:
        logConfig['format'] = '%(asctime)s: %(levelname)s: %(message)s'
        logConfig['datefmt'] = '%m/%d %H:%M:%S'
    logging.basicConfig(**logConfig)
    if os.geteuid() > 0:
        logging.critical("Preemptive permission denied -- exiting because EUID not root")
        exit(1)
    # Handle install/remove-and-exit options
    if opts.install_guard or opts.remove_guard:
        # The two options are mutually exclusive, so install_guard alone
        # tells us which action was requested
        RebootGuard().guard(enforce=bool(opts.install_guard))
        logging.info("Exiting due to -1 or -0 option")
        exit()
    c = ConditionChecker(
        forbidFiles=opts.forbid_file,
        requireFiles=opts.require_file,
        units=opts.unit,
        cmds=opts.cmd,
        cmdArgs=opts.args,
        runCmds=opts.run)
    # Initialize reboot-guard
    r = RebootGuard()
    # Upper bound (seconds) on how long a received signal can go unnoticed
    pollSeconds = 1.0
    # Run forever-loop (the first pass is the initial check)
    while True:
        if c.check_conditions():
            # Also done on the very first pass so that a guard left behind
            # by an earlier instance is removed, as promised by -x's help
            r.guard(enforce=False)
            if opts.exit_on_pass:
                logging.warning("Exiting due to passed condition checks")
                exit()
        else:
            r.guard(enforce=True)
        # Wait out the interval in short slices, reacting to signals promptly
        deadline = time.monotonic() + opts.interval
        while True:
            if sighandler.receivedSignal:
                if sighandler.receivedTermSignal and not opts.ignore_signals:
                    logging.warning("Gracefully exiting due to receipt of signal {}".format(sighandler.lastSignal))
                    r.guard(enforce=False)
                    exit()
                logging.warning("Ignoring signal {}".format(sighandler.lastSignal))
                sighandler.receivedSignal = sighandler.receivedTermSignal = False
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(remaining, pollSeconds))


if __name__ == '__main__':
    # Start the daemon only when this file is executed as a script,
    # not when it is imported as a module
    main()
